import sys

sys.path.append("../")
from highway_env.envs.highway_env import HighwayEnv
import matplotlib.pyplot as plt

screen_width, screen_height = 84, 84

# Patch the class-level default config before any environment is built:
# enable the "offscreen_rendering" flag and replace the observation settings
# with a 4-frame stack of grayscale images of size (screen_width, screen_height).
HighwayEnv.DEFAULT_CONFIG.update({
    "offscreen_rendering": True,
    "observation": {
        "type": "GrayscaleObservation",
        # Per-channel R, G, B weights used for the RGB -> grayscale conversion.
        "weights": [0.2989, 0.5870, 0.1140],
        "stack_size": 4,
        "observation_shape": (screen_width, screen_height),
    },
})

def _show_stack(stack):
    """Plot the four frames of an observation stack on a 2x2 grid and show it.

    Frame k is read as ``stack[:, :, k]``, i.e. the stack axis is assumed to be
    last (TODO confirm against the installed GrayscaleObservation layout).
    """
    fig, ax = plt.subplots(ncols=2, nrows=2)
    fig.set_figheight(10)
    fig.set_figwidth(10)
    # ax.flat walks the 2x2 grid row-major: (0,0), (0,1), (1,0), (1,1).
    for j, axis in enumerate(ax.flat):
        axis.imshow(stack[:, :, j], cmap=plt.get_cmap('gray'))
        axis.set_title("frame #{}".format(j))
    plt.show()


# The script previously used `env` and `init_stack` without ever defining
# them (NameError). Build the env after the DEFAULT_CONFIG edits so it uses
# them. NOTE(review): assumes HighwayEnv() reads DEFAULT_CONFIG when no
# config is passed; confirm against the installed highway_env version.
env = HighwayEnv()
reset_result = env.reset()
# Older gym-style APIs return only the observation from reset(); newer ones
# return (observation, info). Accept either.
init_stack = reset_result[0] if isinstance(reset_result, tuple) else reset_result

print("########################### Initial Stack ###########################")
_show_stack(init_stack)

for i in range(3):
    # Action index 1 (presumably IDLE in the discrete meta-action space; confirm).
    # step() returns a tuple whose first element is the new observation stack.
    out = env.step(1)
    print("########################### New state #{} ###########################".format(i + 1))
    _show_stack(out[0])